<!DOCTYPE html>
<html lang="en">

<head>
    <!-- Document metadata comes first so the character encoding is declared before any script is fetched. -->
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <meta http-equiv="X-UA-Compatible" content="ie=edge">
    <title>RWKV-v2-RNN Playground (Language Model)</title>

    <!-- TensorFlow.js core and its WebAssembly backend; both must load before the inline scripts that use `tf`. -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"> </script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-backend-wasm/dist/tf-backend-wasm.js"></script>
    <script>
        // Run all tensor operations on the WASM backend.
        tf.setBackend('wasm');
        // Inject the hm.baidu.com script in front of the first <script> element of the page.
        var _hmt = _hmt || [];
        (function () {
            var hm = document.createElement("script");
            hm.src = "https://hm.baidu.com/hm.js?1e834a9b11dc71db3d0ae4cbb885253d";
            var s = document.getElementsByTagName("script")[0];
            s.parentNode.insertBefore(hm, s);
        })();
    </script>
</head>
<style>
    /* Reset the default spacing of every element. */
    * {
        margin: 0;
        padding: 0;
    }

    body {
        margin: 0.2em;
        overflow-y: scroll;
        font-family: sans-serif;
        color: #000;
        background-color: #ffd;
    }

    /* Links take the colour of the surrounding text in every state except hover. */
    a,
    a:link,
    a:visited {
        color: inherit;
    }

    a:hover {
        color: #f00 !important;
    }

    /* Geometry and typography shared by the prompt and the log textareas. */
    #textArea,
    #logArea {
        box-sizing: border-box;
        width: 100%;
        height: 75vh;
        padding: 0.3em;
        font-family: monospace;
        font-size: 100%;
        line-height: 1.5;
    }

    /* The log is smaller and greyed to set it apart from the prompt. */
    #logArea {
        font-size: 70%;
        background-color: #ddd;
    }

    #textWrap {
        margin: 0 auto;
    }

    .writeBtn {
        cursor: pointer;
        height: 2em;
        width: 45%;
    }
</style>

<body>
    <div style="text-align: center; line-height: 1.5; font-size: 75%;">
        <!-- rel="noopener" stops the new tab from reaching this page through window.opener. -->
        <p><a href="https://github.com/BlinkDL/RWKV-LM" target="_blank" rel="noopener">RWKV-v2-RNN Playground (Language Model)</a> - JS
            inference in your browser - 26M params char-level model (L6-D512) trained on PG-19 dataset</p>
        <p style="color:blue; font-weight: bold;">NOTE: The PG-19 dataset has a line width of 72, thus the model will
            insert line breaks. Use the same format for your prompt.</p>
    </div>
    <div id="textWrap">
        <div style="height:3em">
            <!-- Status line; the script toggles it against #writeBtns and rewrites its content. -->
            <div id="loading" style="color:blue; font-size: 100%;text-align: center;">
                Loading... <span id="loadprog"></span>
            </div>
            <!-- Hidden until every weight file has been loaded (see showBtn in the script). -->
            <div id="writeBtns" style="width: 100%; display: none; margin-bottom:0.5em; flex-direction:row; justify-content:space-evenly;">
                <button type="button" class="writeBtn" onclick="sendText()">Continue story (alt+Q)</button>
                <button type="button" class="writeBtn" onclick="rewriteText()">Alternative (alt+E)</button>
            </div>
        </div>
        <div style="display:flex; flex-direction:row; justify-content:space-evenly;">
            <!-- aria-label gives each textarea an accessible name; there is no visible <label>. -->
            <textarea id="textArea" style="flex: 2 1 0" aria-label="Story text">In the</textarea>
            <textarea id="logArea" style="flex: 1 1 0" aria-label="Generation log">Write in the left textarea. Here is the log of all generation results.</textarea>
        </div>
    </div>
</body>
<script>
    "use strict";
    // Root of the nested weight tree; loadWeight() stores each tensor under the
    // path spelled by its dot-separated name (gParam.blocks[0].att.key.weight, ...).
    const gParam = {}

    // Model dimensions. 6 layers / 512 channels match the "L6-D512" model named
    // in the page header.
    const n_layer = 6
    const n_embd = 512
    // Longest prompt kept by write() and maximum number of rows kept in `hhh`.
    const ctx_len = 1024
    // NOTE(review): presumably the number of entries in js_model/word-utf8.json - confirm.
    const vocab_size = 3506

    // Names of every weight file, in the order they are requested. The list is
    // derived from n_layer so it cannot drift from the layer count used elsewhere.
    var weightName = (function () {
        // Parameters that exist once per block, in request order.
        const perBlock = [
            'ln1.weight',
            'ln1.bias',
            'ln2.weight',
            'ln2.bias',
            'att.time_decay',
            'att.time_first',
            'att.time_mix',
            'att.key.weight',
            'att.value.weight',
            'att.receptance.weight',
            'att.output.weight',
            'ffn.time_mix',
            'ffn.key.weight',
            'ffn.receptance.weight',
            'ffn.value.weight'
        ]
        const names = ['emb.weight']
        for (let layer = 0; layer < n_layer; layer++) {
            for (const suffix of perBlock)
                names.push('blocks.' + layer + '.' + suffix)
        }
        names.push('ln_out.weight', 'ln_out.bias', 'head.weight', 'head_q.weight', 'head_k.weight')
        return names
    })()

    // Progress counters: the UI is unlocked once LOADED_PARAMS reaches N_PARAMS.
    var N_PARAMS = weightName.length
    var LOADED_PARAMS = 0

    // Vocabulary tables, filled from js_model/word-utf8.json:
    //   itos: token id (JSON key)  -> string
    //   stoi: string               -> token id (integer)
    // Both stay empty objects if the download fails, so later lookups do not
    // hit a null table.
    var request = new XMLHttpRequest()
    var itos = {}
    var stoi = {}
    request.responseType = 'json';
    request.open('GET', "js_model/word-utf8.json", true)
    request.onload = function () {
        var vocab = request.response
        // With responseType 'json' the response is null when the body is
        // missing or is not valid JSON.
        if (request.status >= 400 || vocab === null || typeof vocab !== 'object') {
            console.error('Failed to load vocabulary js_model/word-utf8.json (HTTP ' + request.status + ')')
            return
        }
        itos = vocab
        for (var key in itos) {
            stoi[itos[key]] = parseInt(key, 10)
        }
    }
    request.onerror = function () {
        console.error('Network error while loading vocabulary js_model/word-utf8.json')
    }
    request.send()

    // Download one weight file and store it in the gParam tree.
    //
    // Fetches "js_model/<wName>.bin", reads the bytes as float32 values,
    // reshapes the matrix weights and stores the tensor at the path spelled by
    // the dot-separated name ('blocks.0.att.key.weight' ends up at
    // gParam.blocks[0].att.key.weight). Every successful load increments
    // LOADED_PARAMS and refreshes the percentage in #loadprog; the load that
    // completes the set calls showBtn().
    //
    // wName: dot-separated weight name (one entry of weightName).
    function loadWeight(wName) {
        var url = "js_model/" + wName + ".bin"
        var request = new XMLHttpRequest()
        request.open('GET', url, true)
        // Ask for the bytes directly instead of going through a Blob and a FileReader.
        request.responseType = 'arraybuffer'
        request.onerror = function () {
            console.error('Network error while loading ' + url)
        }
        request.onload = function () {
            // Do not interpret an HTTP error page as float32 weights.
            if (request.status >= 400 || !request.response) {
                console.error('Failed to load ' + url + ' (HTTP ' + request.status + ')')
                return
            }
            var ww = tf.tensor(new Float32Array(request.response))

            // Matrix weights are stored flat; -1 lets tf infer the other dimension.
            if (wName == 'emb.weight' || wName.includes('head') || wName.endsWith('key.weight'))
                ww = ww.reshape([-1, n_embd])
            else if (wName.endsWith('value.weight'))
                ww = ww.reshape([n_embd, -1])
            else if (wName.endsWith('receptance.weight') || wName.endsWith('output.weight'))
                ww = ww.reshape([n_embd, n_embd])

            // Walk or create the intermediate objects, then attach the tensor at the leaf.
            var path = wName.split('.')
            var leaf = path.pop()
            var node = gParam
            for (var part of path) {
                if (!(part in node))
                    node[part] = {}
                node = node[part]
            }
            node[leaf] = ww

            LOADED_PARAMS += 1
            document.getElementById("loadprog").textContent = Math.round(LOADED_PARAMS / N_PARAMS * 100) + '%'
            if (LOADED_PARAMS == N_PARAMS) {
                showBtn()
            }
        }
        request.send()
    }

    // Show the status line (#loading) and hide the two action buttons (#writeBtns).
    function hideBtn() {
        const statusLine = document.getElementById("loading")
        statusLine.style.display = "block"

        const buttonRow = document.getElementById("writeBtns")
        buttonRow.style.display = "none"
    }

    // Hide the status line (#loading) and show the two action buttons (#writeBtns) as a flex row.
    function showBtn() {
        const statusLine = document.getElementById("loading")
        statusLine.style.display = "none"

        const buttonRow = document.getElementById("writeBtns")
        buttonRow.style.display = "flex"
    }

    // Constant tensors; they can only be created once the backend is ready.
    var ln_eps = null,   // epsilon added to the variance in LayerNorm
        const_1 = null   // vector of ones, used to form (1 - time_mix)

    // Recurrent state, keyed by '<layer>.att' / '<layer>.ffn':
    //   xxx - input seen by the block on the previous step
    //   aaa / bbb - running sums maintained by SA()
    var xxx = {},
        aaa = {},
        bbb = {}
    // Rows of head_k outputs accumulated by run(); null before the first step.
    var hhh = null

    // Once the backend is initialised, build the constants and request every weight file.
    tf.ready().then(function () {
        ln_eps = tf.tensor(1e-5)
        const_1 = tf.ones([n_embd])
        for (const name of weightName)
            loadWeight(name)
    });

    // Normalise x to zero mean and unit variance over all of its elements, then
    // apply the learned scale and shift.
    //   x  - input tensor
    //   ln - object with `weight` and `bias` tensors
    // Returns a new tensor; x is not modified.
    function LayerNorm(x, ln) {
        var centered = x.sub(x.mean())
        var variance = centered.square().mean()
        // ln_eps keeps the square root away from zero.
        var std = variance.add(ln_eps).sqrt()
        return centered.div(std).mul(ln.weight).add(ln.bias)
    }

    // Feed-forward step for one block.
    //   xx   - current input vector (combined element-wise with an [n_embd] state)
    //   w    - weights: time_mix, receptance.weight, key.weight, value.weight
    //   name - key of this block's slot in `xxx`
    // Side effect: replaces xxx[name] with the current input (kept past tf.tidy).
    function FF(xx, w, name) {
        // The first call for this block starts from an all-zero previous input.
        if (!(name in xxx))
            xxx[name] = tf.zeros([n_embd])

        // Blend the current input with the previous one: xx * mix + prev * (1 - mix).
        var previous = xxx[name]
        var blended = xx.mul(w.time_mix).add(previous.mul(const_1.sub(w.time_mix)))
        previous.dispose()
        xxx[name] = tf.keep(xx)

        // Column vector so it can be multiplied by the weight matrices.
        var column = blended.expandDims(1)

        var gate = w.receptance.weight.matMul(column).sigmoid()
        var hidden = w.key.weight.matMul(column).relu().square()
        var projected = w.value.weight.matMul(hidden)

        return gate.mul(projected).squeeze()
    }

    // Time-mixing ("attention") step for one block, evaluated recurrently.
    //   xx   - current input vector (combined element-wise with [n_embd] state)
    //   w    - weights: time_mix, time_first, time_decay and the
    //          receptance/key/value/output matrices
    //   name - key of this block's slots in `xxx`, `aaa` and `bbb`
    // Side effects: replaces xxx[name], aaa[name] and bbb[name] with kept tensors.
    function SA(xx, w, name) {
        // The first call for this block starts from all-zero state.
        if (!(name in xxx)) {
            xxx[name] = tf.zeros([n_embd])
            aaa[name] = tf.zeros([n_embd])
            bbb[name] = tf.zeros([n_embd])
        }

        // Blend the current input with the previous one: xx * mix + prev * (1 - mix).
        var previous = xxx[name]
        var blended = xx.mul(w.time_mix).add(previous.mul(const_1.sub(w.time_mix)))
        previous.dispose()
        xxx[name] = tf.keep(xx)
        var column = blended.expandDims(1)

        var r = w.receptance.weight.matMul(column).sigmoid().squeeze()
        // The key is clipped at 60 before exp() to bound the exponential.
        var k = w.key.weight.matMul(column).clipByValue(-999999, 60).exp().squeeze()
        var v = w.value.weight.matMul(column).squeeze()
        var kv = k.mul(v)

        // Numerator / denominator for this step: the current token is weighted by time_first.
        var numerator = aaa[name].add(w.time_first.mul(kv))
        var denominator = bbb[name].add(w.time_first.mul(k))

        // Advance the running sums: decay the old value, then add the current term.
        var oldA = aaa[name].clone()
        var oldB = bbb[name].clone()
        aaa[name].dispose()
        bbb[name].dispose()
        aaa[name] = tf.keep(w.time_decay.mul(oldA).add(kv))
        bbb[name] = tf.keep(w.time_decay.mul(oldB).add(k))

        // 1e-16 guards the division against a zero denominator.
        var mixed = r.mul(numerator).div(denominator.add(1e-16)).expandDims(1)
        return w.output.weight.matMul(mixed).squeeze()
    }

    // Reset the recurrent state (xxx, aaa, bbb, hhh) to its initial empty form.
    // The tensors held there were kept with tf.keep() or created outside
    // tf.tidy(), so they must be disposed explicitly; dropping the references
    // alone would leak their memory on every call.
    function clearStat() {
        for (var table of [xxx, aaa, bbb]) {
            for (var key in table)
                table[key].dispose()
        }
        if (hhh)
            hhh.dispose()

        xxx = {}
        aaa = {}
        bbb = {}
        hhh = null
    }

    // Store a snapshot of the current recurrent state in ctxBuf[name].
    //   out  - logits produced for this context (copied with slice())
    //   name - cache key; callers pass the context string
    // Every state tensor is cloned, so later steps do not affect the snapshot.
    function saveStat(out, name) {
        var snapshot = function (table) {
            var copy = {}
            for (var key in table)
                copy[key] = tf.keep(table[key].clone())
            return copy
        }

        var entry = {}
        ctxBuf[name] = entry
        entry.out = out.slice()
        entry.xxx = snapshot(xxx)
        entry.aaa = snapshot(aaa)
        entry.bbb = snapshot(bbb)
        entry.hhh = tf.keep(hhh.clone())
    }

    // Restore the recurrent state saved under ctxBuf[name] and return a copy of
    // the logits stored with it.
    //   name - cache key previously passed to saveStat()
    // The cached tensors are cloned so the cache entry stays reusable. Any live
    // tensor that is about to be replaced is disposed first; otherwise it would
    // leak, because state tensors are not owned by any tf.tidy() scope.
    function loadStat(name) {
        var buf = ctxBuf[name]

        var restore = function (target, saved) {
            for (var key in saved) {
                if (key in target)
                    target[key].dispose()
                target[key] = saved[key].clone()
            }
        }
        restore(xxx, buf.xxx)
        restore(aaa, buf.aaa)
        restore(bbb, buf.bbb)

        if (hhh)
            hhh.dispose()
        hhh = buf.hhh.clone()

        ctxNow = name
        return buf.out.slice()
    }

    // Advance the model by one step and return the logits for the next token.
    //   ctx - array of token ids; only the LAST id is fed to the network, the
    //         earlier ones are assumed to be reflected in the recurrent state
    //         already (xxx / aaa / bbb / hhh).
    // Returns a typed array of logits indexed by token id.
    // Side effects: updates the recurrent state and sets ctxNow.
    function run(ctx) {
        return tf.tidy(() => {
            var ctxStr = ''
            for (var s of ctx)
                ctxStr += itos[s]
            ctxNow = ctxStr

            // Embedding row of the newest token.
            var x = gParam.emb.weight.slice(ctx[ctx.length - 1], 1).squeeze()

            // n_layer replaces the previously hard-coded 6 so the loop follows the model config.
            for (var layer = 0; layer < n_layer; layer++) {
                var block = gParam.blocks[layer]
                x = LayerNorm(x, block.ln1)
                x = x.add(SA(x, block.att, layer + '.att'))
                x = LayerNorm(x, block.ln2)
                x = x.add(FF(x, block.ffn, layer + '.ffn'))
            }
            x = LayerNorm(x, gParam.ln_out)

            x = x.expandDims(1)
            var hk = gParam.head_k.weight.matMul(x).squeeze().expandDims(0)

            // Append this step's key row to hhh, keeping at most ctx_len rows.
            if (hhh === null) {
                hhh = tf.keep(hk)
            } else {
                // The previous hhh is not owned by this tidy scope, so dispose it by hand.
                var previous = hhh
                if (previous.shape[0] >= ctx_len) {
                    hhh = tf.keep(previous.slice(1, -1).concat(hk))
                } else {
                    hhh = tf.keep(previous.concat(hk))
                }
                previous.dispose()
            }

            // One score per stored row.
            // NOTE(review): 256 is presumably the head_q/head_k projection width - confirm.
            var q = gParam.head_q.weight.matMul(x)
            var c = hhh.matMul(q).div(256).dataSync()

            var logits = gParam.head.weight.matMul(x).dataSync()

            // Add each row's score to the logit of the token at that context
            // position; rows are aligned with the END of ctx.
            var i_delta = ctx.length - c.length
            for (var i = 0; i < c.length; i++)
                logits[ctx[i + i_delta]] += c[i]

            return logits
        })
    }

    var WRITE_EACH_LENGTH = 256,  // tokens generated per button press
        ctx = [1],                // token ids of the active context
        ctxBuf = {},              // cache: context string -> saved state (see saveStat)
        ctxNow = '';              // context string of the most recent run()/loadStat()

    // Pick a token id from raw logits.
    // Tokens are visited from most to least probable and their probabilities
    // are subtracted from a random threshold drawn from [0, topP); the token
    // that uses up the threshold is returned. Only tokens inside the top
    // `topP` probability mass can therefore be chosen.
    //   logits - array-like of scores indexed by token id
    //   topP   - upper bound of the random threshold
    function sampleTopP(logits, topP) {
        var order = Array.from(Array(logits.length).keys())
            .sort((a, b) => logits[a] > logits[b] ? -1 : (logits[b] > logits[a]) | 0)

        // Subtract the largest logit before exp() so large scores cannot overflow
        // to Infinity and turn every probability into NaN.
        var maxLogit = logits[order[0]]
        var weights = new Array(logits.length)
        var total = 0
        for (var i = 0; i < logits.length; i++) {
            weights[i] = Math.exp(logits[i] - maxLogit)
            total += weights[i]
        }

        var ran = Math.random() * topP
        var pick = 0
        // The bound keeps the index valid even if rounding leaves `ran` positive.
        while (pick < order.length - 1) {
            ran -= weights[order[pick]] / total
            if (!(ran > 0))
                break
            pick += 1
        }
        return order[pick]
    }

    // Generate WRITE_EACH_LENGTH tokens, one per timer tick so the page can
    // repaint between tokens. Each step obtains logits for the current `ctx`
    // (from the cache when available), samples a token, appends it to the
    // textareas and to `ctx`, and schedules the next step. The state of the
    // final step is cached. The buttons are shown again when done.
    //   iter - number of tokens generated so far in this run
    function asyncWriteOne(iter = 0) {
        if (!(iter < WRITE_EACH_LENGTH)) {
            showBtn()
            return
        }
        document.getElementById("loading").innerHTML = 'Generating...'
        setTimeout(() => {
            var ctxStr = ''
            for (var s of ctx)
                ctxStr += itos[s]

            var out
            if (ctxStr in ctxBuf) {
                out = loadStat(ctxStr)
            } else {
                out = run(ctx)
                if (iter == WRITE_EACH_LENGTH - 1)
                    saveStat(out, ctxStr)
            }

            var result = sampleTopP(out, 0.7)

            addText(itos[result])
            ctx.push(result)

            asyncWriteOne(iter + 1);
        }, 0);
    }

    // Number of prompt tokens; set by write() before asyncAnalyze() is started.
    var ANALYZE_LENGTH = 0

    // Feed the prompt through the model one prefix per timer tick.
    //   iter     - length of the prefix of `ctx` to process in this step
    //   callback - called once iter reaches ANALYZE_LENGTH
    // A prefix that is already cached is restored instead of recomputed; the
    // state of the last processed prefix is cached.
    function asyncAnalyze(iter = 1, callback) {
        if (!(iter < ANALYZE_LENGTH)) {
            callback()
            return
        }

        setTimeout(() => {
            var prefix = ctx.slice(0, iter)
            var prefixStr = prefix.reduce((acc, s) => acc + itos[s], '')

            if (prefixStr in ctxBuf) {
                loadStat(prefixStr)
            } else {
                var out = run(prefix)
                if (iter == ANALYZE_LENGTH - 1)
                    saveStat(out, prefixStr)
            }

            var percent = Math.round(iter / ANALYZE_LENGTH * 100)
            document.getElementById("loading").innerHTML = 'Analyzing ' + percent + "% ... (this can be parallelized for 100x speedup but I haven't coded it yet)"
            asyncAnalyze(iter + 1, callback);
        }, 0);
    }

    // Start a generation run for the given prompt text.
    //   rawCtx - prompt string (the content of the left textarea)
    // Does nothing until every weight file is loaded. Otherwise it resets the
    // recurrent state, encodes "\n" + prompt (truncated to the last ctx_len
    // characters) into token ids, and either resumes from a cached state or
    // analyses the prompt before generating. A character that is missing from
    // the vocabulary is reported in the status line and aborts the run.
    async function write(rawCtx) {
        if (LOADED_PARAMS != N_PARAMS)
            return

        clearStat()
        hideBtn()
        var loading = document.getElementById("loading")
        loading.innerHTML = 'Generating...'

        // An empty vocabulary means word-utf8.json has not loaded; say so
        // instead of flagging the first character as out-of-vocabulary.
        if (Object.keys(stoi).length == 0) {
            loading.textContent = 'Vocabulary is not loaded yet - please retry in a moment.'
            setTimeout(() => {
                showBtn()
            }, 1000)
            return
        }

        var logArea = document.getElementById("logArea")
        logArea.value += '\n\n'

        // The context always starts with a newline; an empty prompt is just "\n".
        var context = '\n' + rawCtx
        if (context.length > ctx_len) {
            context = context.slice(-ctx_len)
        }

        // One token per UTF-16 code unit; -1 marks a unit that is not in the vocabulary.
        ctx = Array.prototype.map.call(context, x => {
            if (x in stoi)
                return stoi[x];
            else {
                return -1;
            }
        })
        var badindex = ctx.indexOf(-1)
        if (badindex >= 0) {
            // codePointAt() yields the whole character even when it spans two
            // code units; textContent shows it literally, without HTML parsing.
            var badChar = String.fromCodePoint(context.codePointAt(badindex))
            loading.textContent = 'Out-of-vocabulary token [' + badChar + '] - please fix it.'
            setTimeout(() => {
                showBtn()
            }, 1000)
            return
        }

        var ctxStr = ''
        for (var s of ctx)
            ctxStr += itos[s]
        if (ctxStr in ctxBuf) {
            // The exact context is cached: restore it and generate straight away.
            loadStat(ctxStr)
            asyncWriteOne()
        } else if (ctxStr.slice(0, -1) in ctxBuf) {
            // Everything but the last character is cached: one analysis step restores it.
            ANALYZE_LENGTH = ctx.length
            asyncAnalyze(ctx.length - 1, asyncWriteOne)
        } else {
            // Nothing usable is cached: analyse the prompt from the start.
            ANALYZE_LENGTH = ctx.length
            asyncAnalyze(1, asyncWriteOne)
        }
    }

    //=================================================================================================

    // Prompt (left) and log (right) textareas.
    var textArea = document.getElementById("textArea")
    var logArea = document.getElementById("logArea")

    // Restore the content saved by an earlier visit, if there is any.
    let txt_stored = localStorage.getItem('txt_stored');
    let log_stored = localStorage.getItem('log_stored');
    if (txt_stored) {
        textArea.value = txt_stored
    }
    if (log_stored) {
        logArea.value = log_stored
    }

    // Length of the prompt when generation last started; -1 means "never".
    // rewriteText() truncates the prompt back to this position.
    let lastGeneratePosition = -1

    // Start both textareas scrolled to the bottom.
    for (const area of [textArea, logArea]) {
        area.scrollTop = area.scrollHeight;
    }

    // Append generated text to both textareas, keep them scrolled to the bottom
    // and persist their content.
    //   d - text to append
    // Persistence is best effort: localStorage.setItem() throws when the quota
    // is exceeded or storage is unavailable, and an exception here would stop
    // the generation loop that calls this function for every token.
    function addText(d) {
        textArea.value += d
        textArea.scrollTop = textArea.scrollHeight;

        logArea.value += d
        logArea.scrollTop = logArea.scrollHeight;

        try {
            localStorage.setItem('txt_stored', textArea.value);
            localStorage.setItem('log_stored', logArea.value);
        } catch (err) {
            console.warn('Could not save text to localStorage:', err)
        }
    }

    // Save manual edits of either textarea on every input and change event.
    textArea.onchange = textArea.oninput = function (e) {
        localStorage.setItem('txt_stored', textArea.value);
    }
    logArea.onchange = logArea.oninput = function (e) {
        localStorage.setItem('log_stored', logArea.value);
    }

    // Handler of "Continue story": generate from the current prompt.
    // Ignored while the status line is visible, i.e. during loading or generation.
    function sendText() {
        const busy = document.getElementById("loading").style.display != "none"
        if (busy)
            return

        const prompt = document.getElementById("textArea").value
        // Remember where the generated text will start so rewriteText() can cut it off again.
        lastGeneratePosition = prompt.length
        write(prompt)
    }

    // Handler of "Alternative": discard the text generated by the last run and
    // generate again from the same prompt.
    // Ignored while the status line is visible (loading or generating).
    // Without this guard, pressing alt+E during generation would cut the
    // textarea back to the prompt while tokens are still being appended to it.
    function rewriteText() {
        if (document.getElementById("loading").style.display != "none")
            return

        const area = document.getElementById("textArea")
        let txt = area.value
        // Nothing was generated yet: an empty prompt is its own starting point.
        if (lastGeneratePosition == -1 && txt.length == 0)
            lastGeneratePosition = 0
        if (lastGeneratePosition != -1) {
            txt = txt.slice(0, lastGeneratePosition)
            area.value = txt
        }
        sendText()
    }

    // Keyboard shortcuts announced on the buttons: alt+Q continues the story
    // (keyCode 81), alt+E asks for an alternative (keyCode 69). Returning false
    // suppresses the default action of the key press.
    document.onkeydown = function (event) {
        if (!event.altKey)
            return

        switch (event.keyCode) {
            case 81:
                sendText()
                return false
            case 69:
                rewriteText()
                return false
        }
    }
</script>

</html>
